import math
from math import sqrt
import argparse
from pathlib import Path
from unittest import TestCase

# torch

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR

# vision imports

from torchvision import transforms as T
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid, save_image

# dalle classes and utils

from dalle_pytorch import distributed_utils
# from dalle_pytorch import DiscreteVAE
from dalle_pytorch.dalle_pytorch_ori import DiscreteVAE
# from dalle_pytorch.dalle_pytorch_oriema import DiscreteVAE
# from dalle_pytorch.dalle_pytorch_ae import DiscreteVAE

# argument parsing

import sys
sys.path.insert(0, '/home/tiangel/ShapeGF')
try:
    from evaluation.evaluation_metrics import EMD_CD
    eval_reconstruciton = True
except Exception as err:  # noqa: BLE001 - optional dependency, best effort
    # EMD_CD comes from the external ShapeGF checkout put on sys.path above.
    # If it cannot be imported (missing checkout or a failing import inside
    # it), the EMD/CD reconstruction evaluation is skipped. Catching
    # Exception rather than using a bare `except:` keeps Ctrl-C and
    # SystemExit working.
    print(f'EMD_CD unavailable, skipping reconstruction evaluation: {err!r}', file=sys.stderr)
    eval_reconstruciton = False
sys.path.insert(0, '/home/tiangel/PVD')
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from datasets.shapenet_data_sv import *

# sys.path.insert(0, '/home/tiangel/Learning-to-Group')
from IPython import embed
import glob
# from pytorch3d.io import load_ply
from pytorch3d.io import load_ply, save_ply
from torch.utils.data import Dataset
import os
# from partnet.utils.torch_pc import normalize_points as normalize_points_torch

from pytorch3d.io import IO
from pytorch3d.structures import Pointclouds
# from pytorch3d.vis.plotly_vis import AxisArgs, plot_batch_individually, plot_scene
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVOrthographicCameras, 
    PointsRasterizationSettings,
    PointsRenderer,
    PulsarPointsRenderer,
    PointsRasterizer,
    AlphaCompositor,
    NormWeightedCompositor
)
import matplotlib.pyplot as plt
import numpy as np
from geometry_utils import render_pts, rotate_pts, render_pts_with_label
import h5py
from scipy.optimize import linear_sum_assignment
from pytorch3d.loss import chamfer_distance




def normalize_points_torch(points, eps=1e-12):
    """Normalize point cloud

    Each cloud in the batch is translated so its centroid is at the origin,
    then scaled so its farthest point lies at distance 1.

    Args:
        points (torch.Tensor): (batch_size, num_points, 3)
        eps (float): lower bound on the per-cloud radius. It prevents 0/0 -> NaN
            for a degenerate cloud whose points all coincide; such a cloud
            comes back as all zeros.

    Returns:
        torch.Tensor: normalized points, same shape as ``points``

    """
    assert points.dim() == 3 and points.size(2) == 3
    centroid = points.mean(dim=1, keepdim=True)
    points = points - centroid
    # Per-cloud radius: (batch_size, 1, 1), broadcast over points and xyz.
    norm, _ = points.norm(dim=2, keepdim=True).max(dim=1, keepdim=True)
    return points / norm.clamp(min=eps)

def _str2bool(value):
    """Parse a command-line boolean flag value.

    ``type=bool`` is a trap with argparse: ``bool('False')`` is ``True``, so
    ``--aug False`` would silently keep augmentation on. This accepts the
    usual spellings (true/false, t/f, yes/no, y/n, 1/0, on/off;
    case-insensitive) and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser()

parser.add_argument('--image_folder', type = str, required = True,
                    help='path to your folder of images for learning the discrete VAE and its codebook')

parser.add_argument('--image_size', type = int, required = False, default = 128,
                    help='image size')

parser = distributed_utils.wrap_arg_parser(parser)


train_group = parser.add_argument_group('Training settings')

train_group.add_argument('--vae_path', type=str,
                   help='file name of the trained discrete VAE checkpoint, relative to ./outputs/vae_models')

train_group.add_argument('--epochs', type = int, default = 20, help = 'number of epochs')

train_group.add_argument('--batch_size', type = int, default = 8, help = 'batch size')

train_group.add_argument('--learning_rate', type = float, default = 1e-3, help = 'learning rate')

train_group.add_argument('--lr_decay_rate', type = float, default = 0.98, help = 'learning rate decay')

train_group.add_argument('--starting_temp', type = float, default = 1., help = 'starting temperature')

train_group.add_argument('--temp_min', type = float, default = 0.5, help = 'minimum temperature to anneal to')

train_group.add_argument('--anneal_rate', type = float, default = 1e-6, help = 'temperature annealing rate')

train_group.add_argument('--num_images_save', type = int, default = 2, help = 'number of images to save')

model_group = parser.add_argument_group('Model settings')

model_group.add_argument('--num_tokens', type = int, default = 8192, help = 'number of image tokens')

model_group.add_argument('--num_layers', type = int, default = 3, help = 'number of layers (should be 3 or above)')

model_group.add_argument('--num_resnet_blocks', type = int, default = 2, help = 'number of residual net blocks')

model_group.add_argument('--smooth_l1_loss', dest = 'smooth_l1_loss', action = 'store_true')

model_group.add_argument('--emb_dim', type = int, default = 512, help = 'embedding dimension')

model_group.add_argument('--hidden_dim', type = int, default = 256, help = 'hidden dimension')

model_group.add_argument('--dim1', type = int, default = 16,
                         help = 'dim1 model hyperparameter (unused here: model hparams come from the checkpoint)')

model_group.add_argument('--save_img_num', type = int, default = 20,
                         help = 'number of input/reconstruction pairs to export as .ply/.png')

model_group.add_argument('--dim2', type = int, default = 32,
                         help = 'dim2 model hyperparameter (unused here: model hparams come from the checkpoint)')

model_group.add_argument('--final_points', type = int, default = 16,
                         help = 'final_points model hyperparameter (not used by this script)')

model_group.add_argument('--radius', type = float, default = 0.3,
                         help = 'radius model hyperparameter (unused here: model hparams come from the checkpoint)')

model_group.add_argument('--kl_loss_weight', type = float, default = 0., help = 'KL loss weight')

model_group.add_argument('--save_name', type = str, default = '1',
                         help = 'suffix for the output directory and result file names')

model_group.add_argument('--aug', type = _str2bool, default = True,
                         help = 'apply random scale augmentation in PC_Dataset / PC_Dataset_h5 (true/false)')

model_group.add_argument('--emd', action='store_true', help = 'also compute EMD via linear assignment')

model_group.add_argument('--testae', type = _str2bool, default = False,
                         help = 'test-autoencoder flag, true/false (only referenced by commented-out code)')

model_group.add_argument('--category', type = str, default = 'car', help = 'ShapeNet category to evaluate')


args = parser.parse_args()

# constants (read once from the parsed command-line arguments)

IMAGE_SIZE, IMAGE_PATH = args.image_size, args.image_folder

EPOCHS, BATCH_SIZE = args.epochs, args.batch_size
LEARNING_RATE, LR_DECAY_RATE = args.learning_rate, args.lr_decay_rate

NUM_TOKENS, NUM_LAYERS, NUM_RESNET_BLOCKS = (
    args.num_tokens, args.num_layers, args.num_resnet_blocks
)
SMOOTH_L1_LOSS = args.smooth_l1_loss
EMB_DIM, HIDDEN_DIM = args.emb_dim, args.hidden_dim
KL_LOSS_WEIGHT = args.kl_loss_weight

STARTING_TEMP, TEMP_MIN, ANNEAL_RATE = (
    args.starting_temp, args.temp_min, args.anneal_rate
)

NUM_IMAGES_SAVE = args.num_images_save

# initialize the distributed backend selected by the CLI arguments

distr_backend = distributed_utils.set_backend_from_args(args)
distr_backend.initialize()

using_deepspeed = distributed_utils.using_backend(distributed_utils.DeepSpeedBackend)

# data

class PC_Dataset(Dataset):
    """Point clouds stored as ``*.ply`` files in ``<root>/<path>``.

    Each item is ``(points, faces)``, taken from ``load_ply``'s (verts, faces)
    result. ``points`` are centred and scaled by ``normalize_points_torch``
    and, when augmentation is on, multiplied by one random factor drawn
    from U(0.9, 1.05).

    Args:
        path: sub-directory of ``root`` that holds the ``.ply`` files.
        root: dataset root directory (default keeps the original location).
        do_aug: enable the random rescale; ``None`` falls back to ``args.aug``.
    """

    def __init__(self, path, root='/home/tiangel/datasets', do_aug=None):
        self.data_dir = path
        # glob() order is arbitrary; sorting makes index -> file deterministic.
        self.data_list = sorted(glob.glob(os.path.join(root, self.data_dir, '*.ply')))
        self.len = len(self.data_list)
        self.do_aug = args.aug if do_aug is None else do_aug

    def __getitem__(self, index):
        pc = load_ply(self.data_list[index])
        verts, faces = pc[0], pc[1]
        # squeeze(0) drops only the batch dim added for normalize_points_torch;
        # a bare squeeze() would also collapse a single-point cloud.
        points = normalize_points_torch(verts.unsqueeze(0)).squeeze(0)
        if self.do_aug:
            scale = points.new(1).uniform_(0.9, 1.05)
            points[:, 0:3] *= scale
        return (points, faces)

    def __len__(self):
        return self.len

# ds = PC_Dataset(IMAGE_PATH)
class PC_Dataset_h5(Dataset):
    """Point clouds read from the ``'data'`` array of an HDF5 file.

    The whole array is loaded into memory once. Each item is
    ``self.data[index]`` converted with ``torch.Tensor``, normalized by
    ``normalize_points_torch`` and, when augmentation is on, multiplied by
    one random factor drawn from U(0.9, 1.05). The array is presumably
    ``(num_shapes, num_points, 3)``, since normalize_points_torch requires
    3 coordinates (TODO confirm).

    Args:
        path: file name relative to ``root``.
        root: dataset root directory (default keeps the original location).
        do_aug: enable the random rescale; ``None`` falls back to ``args.aug``.
    """

    def __init__(self, path, root='/home/tiangel/datasets/', do_aug=None):
        # np.array copies the dataset into memory, so the file can (and
        # should) be closed right away instead of leaking the handle.
        with h5py.File(os.path.join(root, path), 'r') as f:
            self.data = np.array(f['data'])
        self.len = self.data.shape[0]
        self.do_aug = args.aug if do_aug is None else do_aug

    def __getitem__(self, index):
        pc = torch.Tensor(self.data[index])
        # Add a batch dim for normalize_points_torch and drop only that dim.
        points = normalize_points_torch(pc.unsqueeze(0)).squeeze(0)
        if self.do_aug:
            scale = points.new(1).uniform_(0.9, 1.05)
            points[:, 0:3] *= scale
        return points

    def __len__(self):
        return self.len

def get_mvr_dataset(pc_dataroot, views_root, npoints, category, split='val', sv_samples=200):
    """Build the ShapeNet multi-view point-cloud dataset for one category.

    The train split of ``ShapeNet15kPointClouds`` is constructed only to read
    its ``all_points_mean`` / ``all_points_std``. These are handed to
    ``ShapeNet_Multiview_Points``, presumably so both datasets share the same
    normalization.

    Args:
        pc_dataroot: root of the point clouds. The multi-view cache lives in
            ``<pc_dataroot>/../cache``.
        views_root: root of the multi-view data.
        npoints: points per cloud. Used as the train/test sample sizes and as
            the multi-view ``npoints``.
        category: ShapeNet category name.
        split: split of the multi-view dataset to return (default ``'val'``;
            ``'train'`` has also been used).
        sv_samples: passed through as ``sv_samples`` (default 200).

    Returns:
        The ``ShapeNet_Multiview_Points`` dataset for ``split``.
    """
    tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
        categories=[category], split='train',
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        random_subsample=True)
    te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root,
        cache=os.path.join(pc_dataroot, '../cache'), split=split,
        categories=[category],
        npoints=npoints, sv_samples=sv_samples,
        all_points_mean=tr_dataset.all_points_mean,
        all_points_std=tr_dataset.all_points_std,
    )
    return te_dataset

# Multi-view ShapeNet point clouds for the requested category.
ds = get_mvr_dataset(
    pc_dataroot='/home/tiangel/datasets/ShapeNetCore.v2.PC15k',
    views_root='/home/tiangel/datasets/GenReData/',
    npoints=2048,  # 10000 has also been used here
    category=args.category,
)


def count_parameters(model):
    """Return the total number of scalar entries in *model*'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total
    
# ds = ImageFolder(
#     IMAGE_PATH,
#     T.Compose([
#         T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
#         T.Resize(IMAGE_SIZE),
#         T.CenterCrop(IMAGE_SIZE),
#         T.ToTensor()
#     ])
# )

# A DistributedSampler is built only under the Horovod backend.
data_sampler = (
    torch.utils.data.distributed.DistributedSampler(
        ds,
        num_replicas=distr_backend.get_world_size(),
        rank=distr_backend.get_rank(),
    )
    if distributed_utils.using_backend(distributed_utils.HorovodBackend)
    else None
)

# NOTE(review): `data_sampler` is not passed to this DataLoader, so every
# worker iterates the full dataset in a fixed order. Confirm that this is
# intended for this evaluation script.
dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

#### for collecting training data
# target_list = []
# input_list = []
# for data in dl:
#     # gt_all = data['test_points']
#     # x_all = data['sv_points']
#     target_list.append(data['test_points'])
#     input_list.append(data['sv_points'])
# targets = torch.cat(target_list, dim=0)
# inputs = torch.cat(input_list, dim=0)
# import h5py
# outputs = h5py.File('/home/tiangel/datasets/completion_'+args.category+'_2048_val.h5', 'w')
# # outputs = h5py.File('/home/tiangel/datasets/completion_'+args.category+'_2048.h5', 'w')
# # outputs = h5py.File('/home/tiangel/datasets/completion_'+args.category+'.h5', 'w')
# outputs['inputs'] = inputs.numpy()
# outputs['targets'] = targets.numpy()
# outputs.close()

# embed()
# exit()

# NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
# map_location='cpu' lets a checkpoint saved on any GPU index load on this
# machine. The weights are copied into the (CPU) model and the model is moved
# to the GPU below.
loaded_obj = torch.load(os.path.join('./outputs/vae_models', args.vae_path), map_location='cpu')
vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']

vae = DiscreteVAE(
    **vae_params,
)

# keys = list(weights.keys())
# for k in keys:
#     weights['.'.join(k.split('.')[1:])] = weights[k]
#     weights.pop(k)


vae.load_state_dict(weights)
vae.eval().cuda()

# vae_params = dict(
#     image_size = IMAGE_SIZE,
#     num_layers = NUM_LAYERS,
#     num_tokens = NUM_TOKENS,
#     codebook_dim = EMB_DIM,
#     hidden_dim   = HIDDEN_DIM,
#     num_resnet_blocks = NUM_RESNET_BLOCKS,
#     dim1 = args.dim1,
#     dim2 = args.dim2,
#     radius = args.radius
# )

# vae = DiscreteVAE(
#     **vae_params,
#     smooth_l1_loss = SMOOTH_L1_LOSS,
#     kl_div_loss_weight = KL_LOSS_WEIGHT
# )
# if not using_deepspeed:
#     vae = vae.cuda()


# Explicit check rather than `assert`, which is stripped under `python -O`.
if len(ds) == 0:
    raise RuntimeError('dataset does not contain any point clouds')
if distr_backend.is_root_worker():
    print(f'{len(ds)} samples found for evaluation')

# optimizer

opt = Adam(vae.parameters(), lr = LEARNING_RATE)
# T_max = epochs x full batches per epoch, but never 0. CosineAnnealingLR
# divides/takes a modulo by T_max when stepped, so a dataset smaller than one
# batch (or EPOCHS == 0) would otherwise raise ZeroDivisionError.
sched = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer = opt,
    T_max = max(1, EPOCHS * (len(ds) // BATCH_SIZE)),
)
# sched = ExponentialLR(optimizer = opt, gamma = LR_DECAY_RATE)


if distr_backend.is_root_worker():
    # Weights & Biases experiment tracking happens on the root worker only.
    import wandb

    model_config = {
        'num_tokens': NUM_TOKENS,
        'smooth_l1_loss': SMOOTH_L1_LOSS,
        'num_resnet_blocks': NUM_RESNET_BLOCKS,
        'kl_loss_weight': KL_LOSS_WEIGHT,
    }

    run = wandb.init(project='dalle_train_vae', job_type='train_model', config=model_config)

# distribute

distr_backend.check_batch_size(BATCH_SIZE)
deepspeed_config = {'train_batch_size': BATCH_SIZE}

# DeepSpeed is handed the raw dataset and no scheduler; other backends get
# the DataLoader and our scheduler.
distr_vae, distr_opt, distr_dl, distr_sched = distr_backend.distribute(
    args=args,
    model=vae,
    optimizer=opt,
    model_parameters=vae.parameters(),
    training_data=ds if using_deepspeed else dl,
    lr_scheduler=None if using_deepspeed else sched,
    config_params=deepspeed_config,
)

# Prefer scheduler in `deepspeed_config`: a scheduler returned under DeepSpeed
# is left for DeepSpeed to step. If none comes back, fall back to our own.
using_deepspeed_sched = bool(using_deepspeed and distr_sched is not None)
if distr_sched is None:
    distr_sched = sched


def save_model(path):
    """Save the VAE hyperparameters and weights to *path*.

    Under DeepSpeed, a DeepSpeed checkpoint is also written to
    ``<path without suffix>-ds-cp`` with the hparams as client state. The plain
    ``torch.save`` checkpoint (hparams + weights) is still written afterwards,
    by the root worker only, so there is always a "normal" file to refer to.
    """
    checkpoint = {'hparams': vae_params}

    if using_deepspeed:
        target = Path(path)
        ds_dir = str(target.parent / target.stem) + '-ds-cp'
        distr_vae.save_checkpoint(ds_dir, client_state=checkpoint)

    if distr_backend.is_root_worker():
        torch.save({**checkpoint, 'weights': vae.state_dict()}, path)

# starting temperature

def render_pytorch3d(renderer, pts, count, name, out_dir=None):
    """Render one point cloud with a PyTorch3D points renderer and save a PNG.

    Every point gets the colour (0, 0.5, 0). Pixels that are 0 in every
    channel (presumably background) are painted white before saving.

    Args:
        renderer: called as ``renderer(Pointclouds, gamma=(1e-4,))``.
        pts (torch.Tensor): point coordinates, indexed as ``(num_points, 3)``.
        count (int): zero-padded (4 digits) file-name prefix.
        name (str): file-name suffix.
        out_dir (str, optional): output directory. Defaults to the module-level
            ``save_dir``.

    The image is written to ``<out_dir>/<count:04d>_<name>.png`` at 300 dpi.
    """
    if out_dir is None:
        out_dir = save_dir
    # Same dtype/device as the points, so Pointclouds gets consistent tensors
    # (a hard-coded .cuda() mismatched CPU or non-default-GPU points).
    rgb = torch.zeros_like(pts)
    rgb[:, 1] = 0.5
    point_cloud = Pointclouds(points=[pts], features=[rgb])

    rendered_img = renderer(point_cloud, gamma=(1e-4,))
    # Whiten only pixels where *all* channels are 0. An element-wise
    # `rendered_img == 0` mask would also overwrite the zero R/B channels of
    # the green points themselves.
    background = (rendered_img == 0).all(dim=-1)
    rendered_img[background] = 1

    plt.figure(figsize=(10, 10))
    try:
        plt.imshow(rendered_img[0, ..., :3].detach().cpu().numpy())
        plt.axis("off")
        plt.savefig(os.path.join(out_dir, '%04d' % count + '_' + name + '.png'), dpi=300)
    finally:
        # Always release the figure, even if saving fails.
        plt.close()

global_step = 0
temp = STARTING_TEMP
save_dir = os.path.join('./outputs/vae_outputs', 'test' + args.save_name)
# makedirs also creates a missing ./outputs/vae_outputs (os.mkdir would fail)
# and is a no-op when the directory already exists.
os.makedirs(save_dir, exist_ok=True)
count = 0
cd_loss_list = []
emd_loss_list = []
vq_loss_list = []
perplexity_list = []

# PyTorch3D point-cloud renderer (orthographic camera, 256x256 output).
# NOTE: `T` here rebinds the name bound by `from torchvision import transforms as T`.
R, T = look_at_view_transform(50, 10, 45)
cameras = FoVOrthographicCameras(R=R, T=T, znear=0.01).cuda()
raster_settings = PointsRasterizationSettings(
    image_size=(256, 256),
    radius=0.005,
    points_per_pixel=1,
)
rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
renderer = PulsarPointsRenderer(
    rasterizer=rasterizer,
).cuda()
ori_pts = []
our_pts = []
with torch.no_grad():
    for i, data in enumerate(dl):
        print(len(dl), i)
        gt_all = data['test_points']
        x_all = data['sv_points']

        # VAE input: sv_points flattened to (-1, 2048, 3) and normalized per
        # cloud. It is bound to `images` because the export loop below indexes
        # it in lockstep with `recons` (images[j] <-> recons[j]).
        # Other inputs tried: x_all.reshape(-1, 200, 3), optionally .repeat([1, 3, 1]).
        images = normalize_points_torch(x_all.reshape(-1, 2048, 3).cuda())
        cd_loss, emd_loss, vq_loss, recons, perplexity = vae(
            images,
            return_loss = True,
            return_recons = True,
            return_detailed_loss = True,
            temp = temp
        )
        cd_loss_list.append(cd_loss.detach().cpu().numpy())
        emd_loss_list.append(emd_loss.detach().cpu().numpy())
        vq_loss_list.append(vq_loss.detach().cpu().numpy())
        perplexity_list.append(perplexity.detach().cpu().numpy())
        ori_pts.append(normalize_points_torch(gt_all.reshape(-1, 2048, 3)).cpu().clone())
        our_pts.append(recons.reshape(-1, 2048, 3).detach().cpu().clone())
        for j in range(recons.shape[0]):
            count += 1
            # count is 1-based after the increment, so `<=` exports exactly
            # args.save_img_num input/reconstruction pairs.
            if count <= args.save_img_num:
                save_ply(os.path.join(save_dir, '%04d' % count + '_ori.ply'), images[j])
                save_ply(os.path.join(save_dir, '%04d' % count + '_recon.ply'), recons[j].reshape(-1, 3))
                # NOTE(review): the input is rotated with (100, 15) but the
                # reconstruction with (100, 12). Confirm the views should differ.
                render_pts_with_label(os.path.join(save_dir, '%04d' % count + '_ori.png'), rotate_pts(images[j].cpu().numpy(), 100, 15),  torch.randint(3, [images[j].shape[0], 1]).numpy())
                render_pts_with_label(os.path.join(save_dir, '%04d' % count + '_recon.png'), rotate_pts(recons[j].reshape(-1, 3).detach().cpu().numpy(), 100, 12),  torch.randint(3, [recons[j].reshape(-1, 3).shape[0], 1]).numpy(), point_size=16)
embed()


print('total_num:%d, cd_loss:%.4f emd_loss:%.4f vq_loss:%.4f perplexity:%.2f'%(count, np.mean(np.array(cd_loss_list)), np.mean(np.array(emd_loss_list)), np.mean(np.array(vq_loss_list)),np.mean(np.array(perplexity_list))))
# torch.save does not create missing directories.
results_dir = './outputs/shapenet_results'
os.makedirs(results_dir, exist_ok=True)
torch.save(our_pts, os.path.join(results_dir, 'our_pts_' + args.save_name + '.pt'))
torch.save(ori_pts, os.path.join(results_dir, 'ori_pts_' + args.save_name + '.pt'))
ori_pts = torch.cat(ori_pts)
our_pts = torch.cat(our_pts)
cd_dis = []
emd_dis = []
bs = 256

if eval_reconstruciton:
    # Reference shape i is compared with the reconstructions
    # our_pts[20*i : 20*(i+1)]. This assumes 20 reconstructions per reference
    # shape (TODO confirm against the dataset's number of views).
    for i in range(ori_pts.shape[0]):
        rec_res = EMD_CD(ori_pts[i].unsqueeze(0).repeat([20, 1, 1]).cuda(), our_pts[i*20:(i+1)*20].cuda(), 20, reduced=False)
        print(i, 'CD:', rec_res['MMD-CD'].mean(), 'EMD:', rec_res['MMD-EMD'].mean())
        cd_dis.append(rec_res['MMD-CD'])
        emd_dis.append(rec_res['MMD-EMD'])
else:
    print('eval_reconstruciton is false')
    embed()

embed()
# sys.exit() rather than the site-provided exit(), which is intended for the
# interactive shell and is absent under `python -S`.
sys.exit()

# Batched Chamfer distance (pytorch3d) between reconstructions and references.
# NOTE(review): not reached as written, because the script exits before this point.
# NOTE(review): this pairs our_pts[k] with ori_pts[k]. If there are several
# reconstructions per reference shape (as the EMD_CD loop assumes), these
# pairs do not correspond. Confirm before relying on these numbers.
cd_dis = []  # reset: the EMD_CD loop may already have filled this list
n_pairs = min(our_pts.shape[0], ori_pts.shape[0])
for start in range(0, n_pairs, bs):
    stop = min(start + bs, n_pairs)
    # batch_reduction=None keeps one value per cloud. The final mean is then
    # over clouds and includes the trailing partial batch, which a
    # floor(n / bs) loop would drop.
    batch_cd, _ = chamfer_distance(our_pts[start:stop].cuda(), ori_pts[start:stop].cuda(), batch_reduction=None)
    cd_dis.append(batch_cd.cpu())
final_cd_dis = torch.cat(cd_dis).mean() if cd_dis else torch.tensor(float('nan'))
print('cd_dis:', final_cd_dis)
# tensor(0.0064), model 102.ckpt
# cd_dis: tensor(0.0055), model 99.ckpt
# cd_dis: tensor(0.0046), model 100.ckpt
if args.emd:
    emd_dis = []
    for i in range(n_pairs):
        q1 = our_pts[i]
        q2 = ori_pts[i]
        # cost[a, b] = squared Euclidean distance between q1[a] and q2[b]
        cost = ((q1[:, None, :] - q2[None, :, :]) ** 2).sum(dim=-1)
        # Optimal one-to-one matching of q1's points to q2's points.
        _, col_ind = linear_sum_assignment(cost.numpy())
        diff2 = q1 - q2[torch.as_tensor(col_ind)]
        cur_emd_dis = torch.mean((diff2 ** 2).sum(dim=1))
        print('emd', i, cur_emd_dis)
        emd_dis.append(cur_emd_dis)

embed()

    # if using_deepspeed:
    #     # Gradients are automatically zeroed after the step
    #     distr_vae.backward(loss)
    #     distr_vae.step()
    # else:
    #     distr_opt.zero_grad()
    #     loss.backward()
    #     distr_opt.step()
    # if not using_deepspeed_sched:
    #     distr_sched.step()

    # logs = {}
    # if args.testae:
    #     for j in range(images.shape[0]):
    #         save_ply(os.path.join(save_dir,'%04d'%count+'_ori.ply'), images[j])
    #         save_ply(os.path.join(save_dir,'%04d'%count+'_recons.ply'), recons[j].reshape(-1,3))
    #     count+=1

    # # if i % 100 == 0:
    # else:
    #    if distr_backend.is_root_worker():
    #        k = NUM_IMAGES_SAVE

    #        with torch.no_grad():
    #            codes = vae.get_codebook_indices(images[:k])
    #            hard_recons = vae.decode(codes)

    #        images, recons = map(lambda t: t[:k], (images, recons))
    #        images, recons, hard_recons, codes = map(lambda t: t.detach().cpu(), (images, recons, hard_recons, codes))
    #        # images, recons, hard_recons = map(lambda t: make_grid(t.float(), nrow = int(sqrt(k)), normalize = True, range = (-1, 1)), (images, recons, hard_recons))
    #        for j in range(images.shape[0]):
    #            save_ply(os.path.join(save_dir,'%04d'%count+'_ori.ply'), images[j])
    #            save_ply(os.path.join(save_dir,'%04d'%count+'_recons.ply'), recons[j].reshape(-1,3))
    #            save_ply(os.path.join(save_dir,'%04d'%count+'_hardrecons.ply'), hard_recons[j].reshape(-1,3))
    #            # save_ply(os.path.join('./vae_outputs','test'+args.save_name,'%04d'%count+'_our.ply'), recons[j].reshape(10000,3))
    #            count+=1